Add batched QR support to SOAP optimizer #118
RPrenger wants to merge 1 commit into NVIDIA-NeMo:main
Conversation
SOAP calls `torch.linalg.qr` once per Kronecker factor per parameter per
power iteration step. For a model with N 2D parameters, that is 2N
sequential QR calls per eigenbasis update, each paying ~7.5 ms of CPU
dispatch overhead.
Add a `use_batched_qr` flag that restructures `SOAP.step()` into three
phases — collect, batch-QR, apply — so every QR across all parameters is
dispatched in a single `batched_qr_grouped()` call using multiple CUDA
streams. When the flag is `False` (the default), the original code path is
unchanged.
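The collect → batch-QR → apply idea can be sketched in a few lines. The helper below is a hypothetical stand-in for `batched_qr_grouped` (the real package uses torch tensors and CUDA streams; NumPy is used here so the sketch is self-contained): matrices are grouped by shape so each group is factored in a single stacked QR dispatch instead of one call per matrix.

```python
from collections import defaultdict

import numpy as np


def batched_qr_grouped(matrices):
    """Sketch of grouped QR dispatch (hypothetical, NumPy-based).

    Groups the input matrices by shape, stacks each group, and runs one
    stacked np.linalg.qr call per group, so N matrices of a common shape
    cost one dispatch rather than N.
    """
    groups = defaultdict(list)
    for idx, m in enumerate(matrices):
        groups[m.shape].append(idx)

    results = [None] * len(matrices)
    for indices in groups.values():
        stacked = np.stack([matrices[i] for i in indices])  # (B, n, n)
        q, _ = np.linalg.qr(stacked)  # one dispatch for the whole group
        for i, qi in zip(indices, q):
            results[i] = qi
    return results
```

A power-iteration step then reduces to collecting `kf @ Q` products for every work item, calling `batched_qr_grouped` once, and scattering the orthonormal factors back.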
Benchmarked on H100 (12-layer GPT-2-style transformer, 96 QR ops/step):
hidden=1024, mlp=4096: 1225 ms → 348 ms (3.5× speedup)
hidden=2048, mlp=4096: 2210 ms → 887 ms (2.5× speedup)
Greptile Summary
This PR adds an optional `use_batched_qr` flag to SOAP.
Confidence Score: 2/5
Last reviewed commit: 3eed6a9
if qr_items:
    torch.cuda.nvtx.range_push("batched_qr")
    for _ in range(self.power_iter_steps):
        # Collect: matmul kronecker_factor @ Q for every work item
        q_matrices = []
        for item in qr_items:
            with utils.fp32_matmul_precision(self.qr_fp32_matmul_prec):
                q_matrices.append(item["kf_f"] @ item["Q"])

        # Single batched QR dispatch
        q_results = batched_qr_grouped(q_matrices, num_streams=self.batched_qr_num_streams)

        # Scatter results back
        for item, Q_new in zip(qr_items, q_results):
            item["Q"] = Q_new
    torch.cuda.nvtx.range_pop()

    # ── Finalize eigenbases and project momentum forward ──
    for pw in param_work:
        if not pw["use_qr_batch"]:
            continue

        state = pw["state"]
        updated_eigenbasis_list: list[torch.Tensor] = []

        for ind, (kf, eigenbasis) in enumerate(zip(state["GG"], state["Q"])):
            fi = pw["factor_items"][ind]
            if fi is None:
                # Empty factor (e.g. 1D param with precondition_1d=False)
                updated_eigenbasis_list.append(torch.empty(0, device=kf.device))
            elif "skip" in fi:
                # Adaptive criteria met — keep existing eigenbasis
                updated_eigenbasis_list.append(fi["eigenbasis"])
            else:
                updated_eigenbasis_list.append(fi["Q"])

        state["Q"] = updated_eigenbasis_list

        # Step 3 of eigenbasis update: project momentum to new eigenbasis
        with utils.fp32_matmul_precision(self.qr_fp32_matmul_prec):
            state["exp_avg"] = precondition(
                state["exp_avg"],
                updated_eigenbasis_list,
                dims=[[0], [0]],
            )
Critical: exp_avg left in original basis when all QR items satisfy adaptive criteria
When use_adaptive_criteria=True, Phase 1 unconditionally projects state["exp_avg"] back to the original basis (via precondition(..., dims=[[0], [1]])) for every parameter with use_qr_batch=True (line 403). The matching forward projection (dims=[[0], [0]]) in Phase 2 finalization (lines 496–502) is nested inside if qr_items:.
If on a given step all factors for all parameters that need QR updates satisfy the adaptive criteria, no items are added to qr_items. The entire Phase 2 block (including finalization) is skipped, and state["exp_avg"] remains in the original (unpreconditioned) basis. In Phase 3, it is then fed directly to calculate_adam_update as if it were in the eigenbasis, producing corrupted parameter updates.
Fix: Move the finalization loop (lines 476–502) outside the if qr_items: guard so that momentum forward-projection always executes for every parameter whose exp_avg was projected back in Phase 1, regardless of whether qr_items is empty:
Suggested change (current):

if qr_items:
    torch.cuda.nvtx.range_push("batched_qr")
    for _ in range(self.power_iter_steps):
        # Collect: matmul kronecker_factor @ Q for every work item
        q_matrices = []
        for item in qr_items:
            with utils.fp32_matmul_precision(self.qr_fp32_matmul_prec):
                q_matrices.append(item["kf_f"] @ item["Q"])
        # Single batched QR dispatch
        q_results = batched_qr_grouped(q_matrices, num_streams=self.batched_qr_num_streams)
        # Scatter results back
        for item, Q_new in zip(qr_items, q_results):
            item["Q"] = Q_new
    torch.cuda.nvtx.range_pop()
    # ── Finalize eigenbases and project momentum forward ──
    for pw in param_work:
        if not pw["use_qr_batch"]:
            continue
        state = pw["state"]
        updated_eigenbasis_list: list[torch.Tensor] = []
        for ind, (kf, eigenbasis) in enumerate(zip(state["GG"], state["Q"])):
            fi = pw["factor_items"][ind]
            if fi is None:
                # Empty factor (e.g. 1D param with precondition_1d=False)
                updated_eigenbasis_list.append(torch.empty(0, device=kf.device))
            elif "skip" in fi:
                # Adaptive criteria met — keep existing eigenbasis
                updated_eigenbasis_list.append(fi["eigenbasis"])
            else:
                updated_eigenbasis_list.append(fi["Q"])
        state["Q"] = updated_eigenbasis_list
        # Step 3 of eigenbasis update: project momentum to new eigenbasis
        with utils.fp32_matmul_precision(self.qr_fp32_matmul_prec):
            state["exp_avg"] = precondition(
                state["exp_avg"],
                updated_eigenbasis_list,
                dims=[[0], [0]],
            )

Suggested change (proposed):

# ── Phase 2: batched QR across all parameters ──
if qr_items:
    torch.cuda.nvtx.range_push("batched_qr")
    for _ in range(self.power_iter_steps):
        q_matrices = []
        for item in qr_items:
            with utils.fp32_matmul_precision(self.qr_fp32_matmul_prec):
                q_matrices.append(item["kf_f"] @ item["Q"])
        q_results = batched_qr_grouped(q_matrices, num_streams=self.batched_qr_num_streams)
        for item, Q_new in zip(qr_items, q_results):
            item["Q"] = Q_new
    torch.cuda.nvtx.range_pop()

# ── Finalize eigenbases and project momentum forward (always, for every param that
# projected exp_avg back in Phase 1) ──
for pw in param_work:
    if not pw["use_qr_batch"]:
        continue
    state = pw["state"]
    updated_eigenbasis_list: list[torch.Tensor] = []
    for ind, (kf, eigenbasis) in enumerate(zip(state["GG"], state["Q"])):
        fi = pw["factor_items"][ind]
        if fi is None:
            updated_eigenbasis_list.append(torch.empty(0, device=kf.device))
        elif "skip" in fi:
            updated_eigenbasis_list.append(fi["eigenbasis"])
        else:
            updated_eigenbasis_list.append(fi["Q"])
    state["Q"] = updated_eigenbasis_list
    with utils.fp32_matmul_precision(self.qr_fp32_matmul_prec):
        state["exp_avg"] = precondition(
            state["exp_avg"],
            updated_eigenbasis_list,
            dims=[[0], [0]],
        )
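The failure mode can be reproduced in miniature with plain NumPy (hypothetical variable names; `Q` plays the role of an unchanged eigenbasis): once momentum has been projected back to the original basis, skipping the forward projection leaves a vector that is only valid as eigenbasis coordinates when `Q` happens to be the identity.

```python
import numpy as np

rng = np.random.default_rng(0)
# A non-trivial orthonormal eigenbasis Q and momentum stored in that basis
Q, _ = np.linalg.qr(rng.standard_normal((4, 4)))
exp_avg_eigen = rng.standard_normal(4)

# Phase 1: project momentum back to the original basis
exp_avg_orig = Q @ exp_avg_eigen

# Correct finalization: forward-project into the (here unchanged) eigenbasis
exp_avg_fixed = Q.T @ exp_avg_orig

# Buggy path: finalization skipped, original-basis vector used as-is
exp_avg_buggy = exp_avg_orig

assert np.allclose(exp_avg_fixed, exp_avg_eigen)      # round trip recovers momentum
assert not np.allclose(exp_avg_buggy, exp_avg_eigen)  # skipped projection corrupts it
```

In the real optimizer the corrupted vector would then flow into `calculate_adam_update`, which is why the finalization loop must run unconditionally.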
raise ImportError(
    "batched_qr package is required when use_batched_qr=True. "
    "Install with: pip install -e /path/to/batched_qr_torch"
) from e
Placeholder install path in error message
The error message contains a literal placeholder pip install -e /path/to/batched_qr_torch that provides no actionable guidance to users. When users encounter this error, they will see the placeholder path and not know where to find the package.
Suggested change (current):

raise ImportError(
    "batched_qr package is required when use_batched_qr=True. "
    "Install with: pip install -e /path/to/batched_qr_torch"
) from e

Suggested change (proposed):

raise ImportError(
    "batched_qr package is required when use_batched_qr=True. "
    "See the project README for installation instructions."
) from e
Or provide a concrete PyPI package name / repository URL if available.
Thanks @RPrenger, this is on my list but I never got time for it. A couple of questions:
The code itself also needs more work; the change is too intrusive. The DCO check failed; instructions are in CONTRIBUTING.md.
Also, a heads up: I'm refactoring SOAP before the Megatron integration, like #117. You will need to rebase once the refactoring is done.